# Abstract base class for map-style datasets. Its S3 class name is "dataset",
# which is what `is_map_dataset()`, `[.dataset` and `length.dataset` key on.
Dataset <- R6::R6Class(
  classname = "dataset",
  public = list(
    # Subclasses must override `.getitem()`; the base version only signals
    # that it is not implemented.
    .getitem = function(index) not_implemented_error()
  ),
  # Instances are left unlocked so subclasses can attach fields they did not
  # declare up front (e.g. `self$tensors <- ...` inside `initialize`).
  lock_objects = FALSE
)

# Returns TRUE when `x` carries the "dataset" class, FALSE otherwise.
is_map_dataset <- function(x) {
  has_dataset_class <- inherits(x, what = "dataset")
  has_dataset_class
}

# Finds the `initialize` method that applies to the class generator `x`.
#
# Looks at `x$public_methods$initialize` first; while that is NULL, moves to
# the parent generator returned by `x$get_inherit()` and checks again.
# Returns the first `initialize` function found along that chain.
get_init <- function(x) {
  generator <- x
  while (is.null(generator$public_methods$initialize)) {
    generator <- generator$get_inherit()
  }
  generator$public_methods$initialize
}

#' Helper function to create an R6 class that inherits from the abstract `Dataset` class
#' 
#' All datasets that represent a map from keys to data samples should subclass this 
#' class. All subclasses should overwrite the `.getitem()` method, which supports 
#' fetching a data sample for a given key. Subclasses could also optionally 
#' overwrite `.length()`, which is expected to return the size of the dataset 
#' (e.g. number of samples) used by many sampler implementations 
#' and the default options of [dataloader()].
#' 
#' @section Get a batch of observations:
#' 
#' By default datasets are iterated by returning each observation/item individually.
#' Sometimes it's possible to have an optimized implementation to take a batch
#' of observations (eg, subsetting a tensor by multiple indexes at once is faster than
#' subsetting once for each index), in this case you can implement a `.getbatch` method
#' that will be used instead of `.getitem` when getting a batch of observations within
#' the dataloader.
#' 
#' @note 
#' [dataloader()] by default constructs an index
#' sampler that yields integral indices. To make it work with a map-style
#' dataset with non-integral indices/keys, a custom sampler must be provided.
#' 
#' @param name a name for the dataset. It is also used as the class
#'   for it.
#' @param inherit you can optionally inherit from a dataset when creating a 
#'   new dataset.
#' @param ... public methods for the dataset class
#' @param parent_env An environment to use as the parent of newly-created 
#'   objects.
#' @inheritParams nn_module
#' 
#' @export
dataset <- function(name = NULL, inherit = Dataset, ..., 
                    private = NULL, active = NULL,
                    parent_env = parent.frame()) {
  # Thin wrapper: every argument is forwarded unchanged to `create_class()`,
  # and the result of that call is returned as-is.
  # `parent_env` defaults to the frame that called `dataset()`.
  # NOTE(review): `attr_name = "Dataset"` is a fixed tag passed to
  # `create_class()`; its exact effect is defined there, not in this file.
  create_class(
    name = name, 
    inherit = inherit,
    ...,
    private = private, 
    active = active,
    parent_env = parent_env,
    attr_name = "Dataset"
  )
}

#' @export
`[.dataset` <- function(x, y) {
  # The batched accessor is used only when more than one index is requested
  # AND the dataset actually provides a `.getbatch` method; in every other
  # case the call falls back to `.getitem`.
  use_batch <- length(y) > 1 && !is.null(x$.getbatch)
  if (!use_batch) {
    return(x$.getitem(y))
  }
  x$.getbatch(y)
}

#' @export
length.dataset <- function(x) {
  # The size of a dataset is whatever its own `.length()` method reports.
  n_items <- x$.length()
  n_items
}

#' Dataset wrapping tensors.
#' 
#' Each sample will be retrieved by indexing tensors along the first dimension.
#' 
#' @param ... tensors that have the same size of the first dimension.
#'
#' @export
tensor_dataset <- dataset(
  name = "tensor_dataset",
  # Collects the tensors passed through `...` and stores them in
  # `self$tensors`, after checking that they all have the same size along
  # the first dimension.
  initialize = function(...) {
    tensors <- rlang::list2(...)
    lens <- sapply(tensors, function(x) x$shape[1])
    
    # There must be exactly one distinct first-dimension size. The previous
    # check (`!length(unique(lens))`) was only TRUE when no tensors were
    # supplied, so tensors with different sizes were silently accepted.
    # Comparing against 1 rejects mismatched sizes and still rejects the
    # empty case.
    if (length(unique(lens)) != 1)
      value_error("all tensors must have the same size in the first dimension.")
    
    self$tensors <- tensors
  },
  # Returns a list with one element per stored tensor, each indexed by
  # `index` along the first dimension. A list `index` is flattened first.
  .getitem = function(index) {
    
    if (is.list(index)) {
      index <- unlist(index)
    }
    
    lapply(self$tensors, function(x) {
        x[index, ..]
    })
  },
  # Batched access reuses `.getitem`, which already accepts several indices
  # at once.
  .getbatch = function(index) {
    self$.getitem(index)
  },
  # Number of samples: the first-dimension size of the first stored tensor.
  .length = function() {
    self$tensors[[1]]$shape[1]
  }
)

#' Dataset Subset
#'
#' Subset of a dataset at specified indices.
#'
#' @param dataset  (Dataset): The whole Dataset
#' @param indices  (sequence): Indices in the whole set selected for subset
#'
#' @export
dataset_subset <- dataset(
  # Keeps a reference to the wrapped dataset together with the indices that
  # make up the subset.
  initialize = function(dataset, indices) {
    self$dataset <- dataset
    self$indices <- indices
  },
  
  # Translates a position within the subset into the corresponding index of
  # the wrapped dataset and fetches from there.
  .getitem = function(idx) {
    source_idx <- self$indices[idx]
    self$dataset[source_idx]
  },
  
  # The subset holds as many items as there are selected indices.
  .length = function() {
    length(self$indices)
  }
)






